"""
## Preparation for LaVCa step 1 and step 2: extract CLIP vision features that are used
## to fit the voxel-wise encoding models (step 1)
## and to search for optimal images (step 2).

## Execution example:
# For NSD images
python3 -m LaVCa.step0_feature_extracting \
    --modality image \
    --modality_hparams default \
    --model_names CLIP-ViT-B-32 \
    --model_source transformers \
    --dataset_name nsd \
    --devices 0

# For OpenImages
python3 -m step0_feature_extracting \
    --modality image \
    --modality_hparams default \
    --model_names CLIP-ViT-B-32 \
    --model_source transformers \
    --dataset_name OpenImages \
    --dataset_path ./data/OpenImages \
    --devices 0
    
"""

import argparse
from tqdm import tqdm
from utils.utils import resize_images, load_frames
import os
import numpy as np
import torch
from PIL import Image
import gc
from transformers import AutoProcessor, CLIPVisionModelWithProjection

# Dictionary to store model information.
# Key: model name as given on the command line (--model_names).
# Value: (processor class, model class, processor checkpoint id, model checkpoint id).
# embedding_maker_image unpacks the tuple in this order and loads both checkpoints
# from ./data/model_ckpts/image/<checkpoint id>.
model_dict_image = {
    "CLIP-ViT-B-32": (
        AutoProcessor, 
        CLIPVisionModelWithProjection, 
        "openai/clip-vit-base-patch32", 
        "openai/clip-vit-base-patch32"
    ),
}

def _select_device(devices, first_device):
    """Return the device that model inputs must be moved to."""
    # The caller loads the model onto first_device, so inputs have to follow it
    # regardless of how many device ids were listed.
    if first_device is not None:
        return first_device
    if devices and devices[0] >= 0:
        return f"cuda:{devices[0]}"
    return "cpu"


def _load_images(frame_paths):
    """Read the given image files into memory and close their file handles."""
    images = []
    for frame_path in frame_paths:
        # Image.open is lazy and keeps the file open; copy() forces the pixel
        # data to be read so the handle can be released right away.
        with Image.open(frame_path) as image:
            images.append(image.copy())
    return images


def _embed_images(frame_paths, image_processor, model, device):
    """Preprocess one group of image files and return the model's ``image_embeds``."""
    images = _load_images(frame_paths)
    # A placeholder caption is still passed for every image (as before); only
    # the "pixel_values" entry of the processor output is used.
    prompts = ["dummy"] * len(images)
    processed = image_processor(text=prompts, images=images, return_tensors='pt')
    pixel_values = processed["pixel_values"].to(device)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
    return outputs.image_embeds


def convert_image_to_embedding(
    frames_ds, 
    image_processor, 
    model_name, 
    model, 
    dataset_name, 
    devices, 
    first_device,
    emb_save_path,
    dirname=None,       # Directory name used when the dataset is not nsd
    movname=None,       # File name used when the dataset is nsd
    batch_size=1024
    ):
    """
    Converts images to embeddings and writes them to disk as float16 .npy files.

    If dataset_name is 'nsd':
    - Expects frames_ds to be a list of lists of frame paths.
    - Each inner list is embedded in one forward pass, its embeddings are averaged
      along the batch dimension, and all averaged vectors are stacked and saved as
      {emb_save_path}/layerLast/{movname}.npy ("nsd_embs" when movname is None).

    Otherwise:
    - Expects frames_ds to be a one-dimensional list of frame paths.
    - Computes embeddings in batches of batch_size and saves one file per frame as
      {emb_save_path}/layerLast[/{dirname}]/{frame name}.npy.

    Args:
        frames_ds: Frame paths (nested list for nsd, flat list otherwise).
        image_processor: Processor called with text=/images= whose output provides "pixel_values".
        model_name: Unused here; kept for interface compatibility.
        model: Vision model whose output exposes ``image_embeds``.
        dataset_name: Selects the nsd or the per-frame code path.
        devices: List of device ids; only used to pick a device when first_device is None.
        first_device: Device the model lives on; inputs are moved there.
        emb_save_path: Root directory for the saved features.
        dirname: Optional sub-directory under layerLast (non-nsd path only).
        movname: Output file stem (nsd path only).
        batch_size: Number of frames per forward pass (non-nsd path only).

    Raises:
        ValueError: If batch_size is smaller than 1 on the non-nsd path.
    """
    device = _select_device(devices, first_device)

    # ---- for nsd ----
    if dataset_name == "nsd":
        embs_all = []
        for frame_paths in tqdm(frames_ds):
            embs_batch = _embed_images(frame_paths, image_processor, model, device)

            # Take the mean in the batch dimension, then flatten
            embs_batch = embs_batch.detach().cpu().numpy()
            embs_all.append(embs_batch.mean(axis=0).flatten())

            # Release per-group tensors before the next forward pass
            del embs_batch
            gc.collect()
            torch.cuda.empty_cache()

        # Finally, save all features from all frames in one file
        embs_all = np.array(embs_all)
        print("Final embedding shape:", embs_all.shape)
        embs_all = embs_all.astype(np.float16)

        layer_save_path = os.path.join(emb_save_path, "layerLast")
        os.makedirs(layer_save_path, exist_ok=True)

        if movname is None:
            movname = "nsd_embs"  # Default name if movname is not provided

        np.save(os.path.join(layer_save_path, f"{movname}.npy"), embs_all)
        return

    # ---- Normal processing ----
    if batch_size < 1:
        # A non-positive step would either raise an opaque range() error or
        # silently process nothing.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # The output directory does not depend on the frame, so create it once
    layer_save_path = f"{emb_save_path}/layerLast/{dirname}" if dirname else f"{emb_save_path}/layerLast"
    os.makedirs(layer_save_path, exist_ok=True)

    num_frames = len(frames_ds)
    for batch_start in tqdm(range(0, num_frames, batch_size)):
        batch_paths = frames_ds[batch_start:batch_start + batch_size]
        embs_batch = _embed_images(batch_paths, image_processor, model, device)

        # One device-to-host transfer per batch instead of one per frame
        embs_batch = embs_batch.detach().cpu().numpy().astype(np.float16)

        for frame_path, emb in zip(batch_paths, embs_batch):
            frame_name = os.path.basename(frame_path).replace(".png", "").replace(".jpg", "")
            np.save(f"{layer_save_path}/{frame_name}.npy", emb.flatten())

        del embs_batch
        gc.collect()
        torch.cuda.empty_cache()


def embedding_maker_image(
    model_name, 
    stride, 
    dataset_name, 
    dataset_path, 
    devices, 
    resize=False
):
    """
    Main function to handle image embeddings.
    - Loads the processor and model registered under model_name in model_dict_image.
    - Resizes images if requested.
    - Calls convert_image_to_embedding for every entry returned by load_frames.

    Args:
        model_name: Key into model_dict_image.
        stride: Currently unused; kept for interface compatibility.
        dataset_name: Dataset identifier; "nsd" forces dataset_path to ./data/nsd.
        dataset_path: Dataset root containing frames/ and frames_{W}x{H}px/.
            Falls back to ./data/{dataset_name} when empty or None.
        devices: List of device ids; the first one selects the device
            (a negative id means CPU).
        resize: When True, frames are resized into the target directory first.
    """
    first_device = f"cuda:{devices[0]}" if devices[0] >= 0 else "cpu"

    print("Selected model:", model_name)
    processor_cls, model_cls, processor_path, model_path = model_dict_image[model_name]

    # Features are always stored under the same layout; only the dataset
    # location is special-cased for nsd.
    emb_save_path = f"./data/stim_features/{dataset_name}/image/default/{model_name}"
    if dataset_name == "nsd":
        dataset_path = "./data/nsd"
    elif not dataset_path:
        dataset_path = f"./data/{dataset_name}"
    os.makedirs(emb_save_path, exist_ok=True)

    # Load local (or cached) processor and model
    saved_processor_path = f"./data/model_ckpts/image/{processor_path}"
    saved_model_path = f"./data/model_ckpts/image/{model_path}"
    processor = processor_cls.from_pretrained(saved_processor_path)
    model = model_cls.from_pretrained(saved_model_path, device_map=first_device)

    # Use the image size from the model config, falling back to 224
    # (typical for CLIP) when the config does not define one.
    size = getattr(model.config, "image_size", None) or 224
    if isinstance(size, (list, tuple)):
        width, height = size[0], size[1]
    else:
        width, height = size, size
    print(f"Image size: {width}x{height}px")

    target_directory = f'{dataset_path}/frames_{width}x{height}px'

    # Resize images if requested
    if resize:
        source_directory = f'{dataset_path}/frames'
        resize_images(source_directory, target_directory, width, height)

    # Load all frames from the target directory
    frames_all = load_frames(target_directory, dataset_name)

    # Process each video or image collection
    for movname, frame_paths in frames_all.items():
        # NOTE(review): convert_image_to_embedding writes below
        # {emb_save_path}/layerLast/, so this path does not look like it is ever
        # created by this script - confirm whether the skip is expected to trigger.
        save_check_path = f"{emb_save_path}/{movname}.npy"
        if os.path.exists(save_check_path):
            print(f"Already exists: {movname} with {len(frame_paths)} frames")
            continue

        print(f"Now processing: {movname} with {len(frame_paths)} frames")

        # movname is deliberately passed as dirname (this matches the previous
        # positional call): non-nsd outputs go to layerLast/{movname}/, while the
        # nsd branch ignores dirname and keeps its default file name.
        convert_image_to_embedding(
            frame_paths,
            processor,
            model_name,
            model,
            dataset_name,
            devices,
            first_device,
            emb_save_path,
            dirname=movname,
        )


def main():
    """
    Command-line entry point.
    Builds the argument parser, reads the options and hands them to embedding_maker_image.
    """
    cli = argparse.ArgumentParser(description="Script for encoding images into embeddings.")

    # (flag, add_argument keyword arguments), registered in this exact order
    option_table = (
        ("--modality", {"type": str, "default": "image", "help": "Type of modality (default: image)"}),
        ("--modality_hparams", {"type": str, "default": "default", "help": "Hyperparameters for modality (default: default)"}),
        ("--model_names", {"type": str, "default": "CLIP-ViT-B-32", "help": "Name of the model to use"}),
        ("--model_source", {"type": str, "default": "transformers", "help": "Source of the model (default: transformers)"}),
        ("--dataset_name", {"type": str, "required": True, "help": "Name of the dataset (e.g., nsd, OpenImages, etc.)"}),
        ("--dataset_path", {"type": str, "default": None, "help": "Path to the dataset if needed"}),
        ("--devices", {"type": str, "default": "0", "help": "Comma-separated list of device IDs (e.g., '0' or '0,1')"}),
        ("--stride", {"type": int, "default": 1, "help": "Stride for downsampling frames (default: 1)"}),
        ("--resize", {"action": "store_true", "help": "Whether to resize images before embedding"}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()

    # "0,1" -> [0, 1]
    device_ids = [int(token) for token in opts.devices.split(",")]

    # Fall back to the conventional location when no dataset path was given
    dataset_root = opts.dataset_path or f"./data/{opts.dataset_name}"

    embedding_maker_image(
        model_name=opts.model_names,
        stride=opts.stride,
        dataset_name=opts.dataset_name,
        dataset_path=dataset_root,
        devices=device_ids,
        resize=opts.resize,
    )


# Run the command-line interface only when this file is executed as a script
# (or via ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()
